# coding:utf-8

from src.manager.oracle_manager import OracleManager
from src.manager.tree_manager import TreeManager
from src.constant.oracle import Oracle
from src.manager.log_manager import LogManager
from src.manager.file_manager import FileManager
from src.constant.file_and_path_constant import FileAndPathConstant
from src.config.date_config import DateConfig
from src.config.tree_config import TreeConfig
from numpy import *
import os
import shutil

# Module-level logger obtained from the project's LogManager, keyed by this module's name.
Logger = LogManager.get_logger(__name__)


class TreeHandler:

    def __init__(self):
        """
        Create the handler's database objects.

        Builds an ``OracleManager`` from the ``Oracle`` constants (user name,
        password, URL), opens its connection, and keeps a cursor obtained from
        it on the instance as ``self.cursor``.
        """
        manager = OracleManager(Oracle.Username, Oracle.Password, Oracle.Url)
        self.oracle_manager = manager
        # Connect first; the cursor is requested from the manager afterwards.
        manager.connect()
        self.cursor = manager.get_cursor()

    def close_cursor_and_connect(self):
        """
        Close the cursor and then the database connection.

        The connection is closed even when closing the cursor raises, so a
        failing cursor cannot leak the connection; the cursor's exception
        still propagates to the caller.

        :return: None
        """
        try:
            self.cursor.close()
        finally:
            self.oracle_manager.connect_close()

    def collect_and_store_data_in_file(self, characteristic_value_number_list):
        """
        Query the monthly analysis rows of every stock and append
        sliding-window vectors to one file per stock.

        For each feature count ``n`` in the argument and each stock code, the
        rows between the configured begin and end dates are fetched in
        ascending ``end_date`` order. Every run of ``n + 1`` consecutive rows
        produces one comma-separated line made of:

        * columns 1-3 of the window's first row (begin date, end date, stock
          code);
        * columns 4-6 of that first row (ma_trend, macd_trend, kd_trend);
        * column 7 of every row in the window. The last of these is the value
          ``create_data_set`` later reads back as the label.

        A stock with fewer than ``n + 1`` rows produces no output, and a
        negative ``n`` produces no output at all.

        The cursor and the connection are closed when the method ends, whether
        it succeeds or raises, so the handler cannot run further queries
        afterwards.

        :param characteristic_value_number_list: iterable of integer feature counts
        :return: None
        """
        date_format = '%Y-%m-%d %H:%M:%S'

        # NOTE(review): the SQL is assembled by string concatenation. The values
        # come from DateConfig and from the table itself rather than from user
        # input, but bind variables would be the safer form.
        query_prefix = ("select * from mdl_stock_month_analysis t "
                        "where t.begin_date >= to_date('" + DateConfig.Mdl_Stock_Month_Analysis_Begin_Date
                        + "', 'yyyy-mm-dd') and t.end_date <= to_date('" + DateConfig.Mdl_Stock_Month_Analysis_End_Date
                        + "', 'yyyy-mm-dd') ")

        try:
            # The set of stock codes does not depend on the feature count, so it
            # is fetched once instead of once per feature count.
            self.cursor.execute('select distinct t.code_ from mdl_stock_month_analysis t')
            stock_code_tuple_list = self.cursor.fetchall()

            for characteristic_value_number in characteristic_value_number_list:
                Logger.info('查询特征向量和标签，并存储在文件中，特征值数量为【' + str(characteristic_value_number) + '】')

                # The string form names the output directory; the int form sizes the window.
                characteristic_value_number_for_filename = str(characteristic_value_number)
                feature_count = int(characteristic_value_number_for_filename)

                for stock_code_tuple in stock_code_tuple_list:
                    stock_code = stock_code_tuple[0]
                    Logger.info('开始收集股票【' + stock_code + '】的数据，特征值数量为【' + str(characteristic_value_number) + '】')

                    # Rows of this stock, oldest first.
                    self.cursor.execute(query_prefix
                                        + "and t.code_='" + stock_code + "' "
                                        + "order by t.end_date asc")
                    mdl_stock_month_analysis_tuple_list = self.cursor.fetchall()

                    vector_file = (FileAndPathConstant.Tree_Vector_File_Path + '/'
                                   + characteristic_value_number_for_filename
                                   + FileAndPathConstant.Tree_Training_Data_Path + '/'
                                   + stock_code + FileAndPathConstant.Tree_Vector_File_Extension_Name)

                    # A window needs feature_count + 1 rows, so only the first
                    # len - feature_count start positions yield a full window.
                    if feature_count >= 0:
                        window_count = len(mdl_stock_month_analysis_tuple_list) - feature_count
                    else:
                        window_count = 0

                    for start in range(window_count):
                        window = mdl_stock_month_analysis_tuple_list[start:start + feature_count + 1]
                        first_row = window[0]

                        up_down_vector_list = [
                            first_row[1].strftime(date_format),  # begin date
                            first_row[2].strftime(date_format),  # end date
                            first_row[3],                        # stock code
                            str(first_row[4]),                   # ma_trend
                            str(first_row[5]),                   # macd_trend
                            str(first_row[6]),                   # kd_trend
                        ]
                        # Column 7 of each row in the window; the final one is the label.
                        up_down_vector_list.extend(str(row[7]) for row in window)

                        FileManager.write(vector_file, 'a', ','.join(up_down_vector_list) + '\n')
        finally:
            # Release the database resources even when a query or a write failed.
            self.close_cursor_and_connect()

    def distinguish_training_data_and_testing_data(self):
        """
        Move a random share of each training directory's files into the
        matching testing directory.

        For every entry under ``FileAndPathConstant.Tree_Vector_File_Path``,
        ``round(file_count * TreeConfig.Testing_DataPercentage)`` distinct
        files are picked at random from its training sub-directory and moved
        to its testing sub-directory. The count is clamped to the number of
        available files, entries without a training directory are skipped, and
        a missing testing directory is created.

        :return: None
        """
        # Local, explicit import of the standard-library module. The module-level
        # name ``random`` is whatever ``from numpy import *`` exports, and
        # ``numpy.random.randint`` excludes its upper bound, so the previous
        # ``randint(0, len - 1)`` could never select the last file and raised
        # ValueError when only one file was left.
        import random as stdlib_random

        Logger.info('从训练数据中挑选测试数据，并将其移动到testing目录中')

        base_path = FileAndPathConstant.Tree_Vector_File_Path
        for vector_file_path in os.listdir(base_path):
            # Paths are built the same way as in collect_and_store_data_in_file and
            # do_tree: the sub-directory constants are appended without an extra '/'.
            # NOTE(review): this presumes those constants begin with a path
            # separator -- confirm against FileAndPathConstant.
            characteristic_path = base_path + '/' + vector_file_path
            training_path = characteristic_path + FileAndPathConstant.Tree_Training_Data_Path
            testing_path = characteristic_path + FileAndPathConstant.Tree_Testing_Data_Path

            if not os.path.isdir(training_path):
                # Stray file or a directory without training data: nothing to split.
                continue

            vector_file_list = os.listdir(training_path)

            # Number of files that should become testing data, kept within
            # [0, number of available files].
            testing_data_file_number = int(round(len(vector_file_list) * TreeConfig.Testing_DataPercentage))
            testing_data_file_number = max(0, min(testing_data_file_number, len(vector_file_list)))
            if testing_data_file_number == 0:
                continue

            if not os.path.isdir(testing_path):
                os.makedirs(testing_path)

            # sample() returns distinct elements, so no file is selected twice.
            for vector_file in stdlib_random.sample(vector_file_list, testing_data_file_number):
                shutil.move(os.path.join(training_path, vector_file),
                            os.path.join(testing_path, vector_file))

    def do_tree(self, characteristic_value):
        """
        Build a decision tree from the training files and log its accuracy on
        the testing files.

        Both data sets are loaded with ``create_data_set`` from the directory
        named after ``characteristic_value``. The tree is created with
        ``TreeManager.create_tree``, every testing sample is classified with
        ``TreeManager.classify``, and the percentage of correct results is
        logged. When the testing set is empty a message is logged instead of
        computing a percentage.

        :param characteristic_value: feature count; its string form names the data directory
        :return: None
        """
        data_root = FileAndPathConstant.Tree_Vector_File_Path + '/' + str(characteristic_value)

        # Training data set
        training_group_ndarray, training_label_list = self.create_data_set(
            data_root + FileAndPathConstant.Tree_Training_Data_Path)
        # Testing data set
        testing_group_ndarray, testing_label_list = self.create_data_set(
            data_root + FileAndPathConstant.Tree_Testing_Data_Path)

        # The tree code works on plain lists rather than ndarrays.
        training_group_list = training_group_ndarray.tolist()
        testing_group_list = testing_group_ndarray.tolist()

        # Build the decision tree. The original code re-read the training set
        # from disk afterwards because the labels were reported as no longer
        # accurate once create_tree had run; handing create_tree its own copy
        # keeps training_label_list intact without a second read.
        tree_manager = TreeManager()
        my_tree = tree_manager.create_tree(training_group_list, list(training_label_list))
        Logger.info('决策树：' + str(my_tree))

        if not testing_group_list:
            # No testing samples: a percentage would divide by zero.
            Logger.info('特征值为【' + str(characteristic_value) + '】时，测试数据集为空，无法计算正确率')
            return

        # Count the correctly classified testing samples.
        success_number = 0
        for testing_group, expected_label in zip(testing_group_list, testing_label_list):
            result = tree_manager.classify(my_tree, training_label_list, testing_group)
            if result == expected_label:
                success_number = success_number + 1

        Logger.info('特征值为【' + str(characteristic_value) + '】时，'
                    + '正确率：' + str(success_number / len(testing_group_list) * 100))

    def create_data_set(self, path):
        """
        Load every vector file under ``path`` into a feature array and a label
        list. Usable for both the training and the testing data set.

        Each non-blank line is split on commas: the first three columns are
        skipped, the last column is the label, and the columns in between are
        the features. Features and labels are parsed with ``int``. Blank or
        whitespace-only lines are ignored.

        :param path: directory whose files are all read
        :return: tuple ``(features, labels)``; ``features`` is ``array(...)``
                 over the per-row feature lists, and ``labels`` is a list of
                 ints in the same row order
        :raises ValueError: if a feature or label column cannot be parsed as an integer
        """
        Logger.info('读取指定路径下的文件，并将特征值和标签分别存储在numpy.ndarray类型的数组和一维数组中')

        # One feature list per row
        group_2d_list = []
        # One label per row
        label_list = []

        for file_name in os.listdir(path):
            content = FileManager.read(os.path.join(path, file_name), 'r')
            for row in content.split('\n'):
                # Skip empty lines, including a lone '\r' left by CRLF line endings.
                if not row.strip():
                    continue
                column_list = row.split(',')
                # Columns 0-2 are skipped; the last column is the label.
                group_2d_list.append([int(column) for column in column_list[3:-1]])
                label_list.append(int(column_list[-1]))

        return array(group_2d_list), label_list
